{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 最小生成树"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"images/weightedgraph.png\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjacency list of the weighted undirected graph: graph_dict[u][v] is the\n",
    "# weight of edge u-v, and every edge is listed under both of its endpoints.\n",
    "graph_dict = {0:{1:6,2:3},\n",
    "              1:{0:6,2:2,3:7,6:4},\n",
    "              2:{0:3,1:2,4:2,5:1,7:10},\n",
    "              3:{1:7},\n",
    "              4:{2:2,6:1},\n",
    "              5:{2:1,7:8},\n",
    "              6:{1:4,4:1},\n",
    "              7:{2:10,5:8}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjacency matrix of the same graph: graph_matrix[i][j] is the weight of\n",
    "# edge i-j, and inf marks a missing edge (the diagonal included).\n",
    "inf = float('inf')\n",
    "graph_matrix = [[inf, 6, 3, inf, inf, inf, inf, inf],\n",
    "                [6, inf, 2, 7, inf, inf, 4, inf],\n",
    "                [3, 2, inf, inf, 2, 1, inf, 10],\n",
    "                [inf, 7, inf, inf, inf, inf, inf, inf],\n",
    "                [inf, inf, 2, inf, inf, inf, 1, inf],\n",
    "                [inf, inf, 1, inf, inf, inf, inf, 8],\n",
    "                [inf, 4, inf, inf, 1, inf, inf, inf],\n",
    "                [inf, inf, 10, inf, inf, 8, inf, inf],\n",
    "               ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from heapq import *\n",
    "\n",
    "# Prim's algorithm (adjacency list)\n",
    "def prim_dict(graph_dict):\n",
    "    \"\"\"Build a minimum spanning tree, or forest, with Prim's algorithm.\n",
    "\n",
    "    Args:\n",
    "        graph_dict: adjacency list {node: {neighbour: weight}} of an\n",
    "            undirected graph. Every neighbour is expected to be a key of\n",
    "            graph_dict as well.\n",
    "\n",
    "    Returns:\n",
    "        None for an empty graph, otherwise the selected edges as\n",
    "        (from_node, to_node, weight) tuples in the order they were chosen.\n",
    "        When the graph is disconnected the search restarts in each\n",
    "        unreached component, so the result is a minimum spanning forest.\n",
    "    \"\"\"\n",
    "    if not graph_dict: return None\n",
    "\n",
    "    size = len(graph_dict)\n",
    "    visit = set()\n",
    "    res = []\n",
    "\n",
    "    for start in graph_dict:\n",
    "        # Every vertex has been reached: nothing left to connect\n",
    "        if len(visit) >= size: break\n",
    "        # Already part of a tree grown from an earlier start\n",
    "        if start in visit: continue\n",
    "\n",
    "        heap = []\n",
    "        new_node = start\n",
    "        visit.add(new_node)\n",
    "\n",
    "        while len(visit) < size:\n",
    "            # Push the edges that leave the most recently added node\n",
    "            for adj_node, weight in graph_dict[new_node].items():\n",
    "                if adj_node not in visit: heappush(heap, (weight, new_node, adj_node))\n",
    "            # Drop stale edges whose far end has been reached in the meantime\n",
    "            while heap and heap[0][2] in visit: heappop(heap)\n",
    "            # An empty heap means this component is exhausted\n",
    "            # (the graph is disconnected), so move on to the next start\n",
    "            if not heap: break\n",
    "            # Take the cheapest edge leaving the tree and record it\n",
    "            min_weight, old_node, new_node = heappop(heap)\n",
    "            visit.add(new_node)\n",
    "            res.append((old_node, new_node, min_weight))\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, 2, 3), (2, 5, 1), (2, 1, 2), (2, 4, 2), (4, 6, 1), (1, 3, 7), (5, 7, 8)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prim on the adjacency list; edges are shown as (from_node, to_node, weight)\n",
    "prim_dict(graph_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from heapq import *\n",
    "\n",
    "# Prim's algorithm (adjacency matrix)\n",
    "def prim_matrix(graph_matrix):\n",
    "    \"\"\"Build a minimum spanning tree, or forest, with Prim's algorithm.\n",
    "\n",
    "    Args:\n",
    "        graph_matrix: square adjacency matrix; graph_matrix[i][j] is the\n",
    "            weight of edge i-j and float('inf') marks a missing edge.\n",
    "\n",
    "    Returns:\n",
    "        None for an empty matrix, otherwise the selected edges as\n",
    "        (from_node, to_node, weight) tuples in the order they were chosen.\n",
    "        When the graph is disconnected the search restarts in each\n",
    "        unreached component, so the result is a minimum spanning forest.\n",
    "    \"\"\"\n",
    "    if not graph_matrix or not graph_matrix[0]: return None\n",
    "\n",
    "    size = len(graph_matrix)\n",
    "    no_edge = float('inf')\n",
    "    visit = set()\n",
    "    res = []\n",
    "\n",
    "    for start in range(size):\n",
    "        # Already part of a tree grown from an earlier start\n",
    "        if start in visit: continue\n",
    "\n",
    "        heap = []\n",
    "        new_node = start\n",
    "        visit.add(new_node)\n",
    "\n",
    "        while len(visit) < size:\n",
    "            # Push the edges from the most recently added node to unreached nodes\n",
    "            for adj_node in range(size):\n",
    "                if adj_node in visit: continue\n",
    "                weight = graph_matrix[new_node][adj_node]\n",
    "                if weight < no_edge:\n",
    "                    heappush(heap, (weight, new_node, adj_node))\n",
    "            # Drop stale edges whose far end has been reached in the meantime\n",
    "            while heap and heap[0][2] in visit: heappop(heap)\n",
    "            # An empty heap means this component is exhausted\n",
    "            # (the graph is disconnected), so move on to the next start\n",
    "            if not heap: break\n",
    "            # Take the cheapest edge leaving the tree and record it\n",
    "            min_weight, old_node, new_node = heappop(heap)\n",
    "            visit.add(new_node)\n",
    "            res.append((old_node, new_node, min_weight))\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, 2, 3), (2, 5, 1), (2, 1, 2), (2, 4, 2), (4, 6, 1), (1, 3, 7), (5, 7, 8)]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prim on the adjacency matrix; edges are shown as (from_node, to_node, weight)\n",
    "prim_matrix(graph_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Union-find (disjoint-set forest)\n",
    "class Node(object):\n",
    "    \"\"\"Element of a disjoint-set forest.\n",
    "\n",
    "    Attributes:\n",
    "        val: payload of the node; ordering and equality are based on it.\n",
    "        father: parent node in the forest, or None when the node is a root.\n",
    "    \"\"\"\n",
    "    def __init__(self, val=None, father=None):\n",
    "        self.val = val\n",
    "        self.father = father\n",
    "\n",
    "    # Ordering and equality delegate to val so that (weight, node, node)\n",
    "    # tuples stay comparable inside a heap when weights tie.\n",
    "    # Defining __eq__ without __hash__ makes instances unhashable.\n",
    "    def __lt__(self, other):\n",
    "        return self.val < other.val\n",
    "\n",
    "    def __eq__(self, other):\n",
    "        return self.val == other.val\n",
    "\n",
    "def find(a):\n",
    "    \"\"\"Return the root of the set containing a, compressing the path.\n",
    "\n",
    "    Iterative on purpose: a long parent chain must not hit the recursion limit.\n",
    "    \"\"\"\n",
    "    # First pass: walk up to the root\n",
    "    root = a\n",
    "    while root.father is not None:\n",
    "        root = root.father\n",
    "    # Second pass (path compression): hang every visited node off the root\n",
    "    while a is not root:\n",
    "        parent = a.father\n",
    "        a.father = root\n",
    "        a = parent\n",
    "    return root\n",
    "\n",
    "def same(a, b):\n",
    "    \"\"\"Return True when a and b belong to the same set.\"\"\"\n",
    "    # Compare roots by identity: two distinct roots may carry equal values\n",
    "    return find(a) is find(b)\n",
    "\n",
    "def union(a, b):\n",
    "    \"\"\"Merge the sets containing a and b; no-op when they already coincide.\"\"\"\n",
    "    root_a = find(a)\n",
    "    root_b = find(b)\n",
    "    if root_a is root_b: return\n",
    "    root_a.father = root_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from heapq import *\n",
    "\n",
    "# Kruskal's algorithm (adjacency list)\n",
    "def kruskal_dict(graph_dict):\n",
    "    \"\"\"Build a minimum spanning tree, or forest, with Kruskal's algorithm.\n",
    "\n",
    "    Uses the union-find helpers Node, same and union defined in an earlier\n",
    "    cell of this notebook.\n",
    "\n",
    "    Args:\n",
    "        graph_dict: adjacency list {node: {neighbour: weight}} of an\n",
    "            undirected graph. Every neighbour must also be a key of\n",
    "            graph_dict.\n",
    "\n",
    "    Returns:\n",
    "        None for an empty graph, otherwise the selected edges as\n",
    "        (node1, node2, weight) tuples in non-decreasing weight order.\n",
    "    \"\"\"\n",
    "    if not graph_dict: return None\n",
    "\n",
    "    heap = []\n",
    "    res = []\n",
    "\n",
    "    # One disjoint-set node per vertex\n",
    "    node_dict = {node: Node(node) for node in graph_dict}\n",
    "\n",
    "    # Push every edge onto the min-heap. An edge whose mirror (same weight,\n",
    "    # endpoints swapped) is already queued is skipped; the set makes that\n",
    "    # check constant-time instead of a scan over the whole heap.\n",
    "    pushed = set()\n",
    "    for node in graph_dict:\n",
    "        for adj_node, weight in graph_dict[node].items():\n",
    "            if (weight, adj_node, node) in pushed: continue\n",
    "            pushed.add((weight, node, adj_node))\n",
    "            heappush(heap, (weight, node_dict[node], node_dict[adj_node]))\n",
    "\n",
    "    # Pop edges from cheapest to most expensive; a forest over n vertices\n",
    "    # holds at most n - 1 edges, so stop as soon as that many are chosen\n",
    "    max_edges = len(node_dict) - 1\n",
    "    while heap and len(res) < max_edges:\n",
    "        weight, node1, node2 = heappop(heap)\n",
    "        # Both endpoints already in the same tree: the edge would close a cycle\n",
    "        if same(node1, node2): continue\n",
    "        # Join the two trees and record the edge\n",
    "        union(node1, node2)\n",
    "        res.append((node1.val, node2.val, weight))\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(2, 5, 1), (4, 6, 1), (1, 2, 2), (2, 4, 2), (0, 2, 3), (1, 3, 7), (5, 7, 8)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Kruskal on the adjacency list; edges are shown as (node1, node2, weight)\n",
    "kruskal_dict(graph_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from heapq import *\n",
    "\n",
    "# Kruskal's algorithm (adjacency matrix)\n",
    "def kruskal_matrix(graph_matrix):\n",
    "    \"\"\"Build a minimum spanning tree, or forest, with Kruskal's algorithm.\n",
    "\n",
    "    Uses the union-find helpers Node, same and union defined in an earlier\n",
    "    cell of this notebook.\n",
    "\n",
    "    Args:\n",
    "        graph_matrix: square adjacency matrix; graph_matrix[i][j] is the\n",
    "            weight of edge i-j and float('inf') marks a missing edge. Only\n",
    "            the upper triangle (j > i) is read.\n",
    "\n",
    "    Returns:\n",
    "        None for an empty matrix, otherwise the selected edges as\n",
    "        (node1, node2, weight) tuples in non-decreasing weight order.\n",
    "    \"\"\"\n",
    "    if not graph_matrix or not graph_matrix[0]: return None\n",
    "\n",
    "    size = len(graph_matrix)\n",
    "    no_edge = float('inf')\n",
    "    heap = []\n",
    "    res = []\n",
    "\n",
    "    # One disjoint-set node per matrix index. This is derived from the\n",
    "    # argument itself, not from the notebook-level graph_dict.\n",
    "    node_dict = {node: Node(node) for node in range(size)}\n",
    "\n",
    "    # Push every existing edge of the upper triangle onto the min-heap\n",
    "    for i in range(size):\n",
    "        for j in range(i + 1, size):\n",
    "            weight = graph_matrix[i][j]\n",
    "            if weight < no_edge:\n",
    "                heappush(heap, (weight, node_dict[i], node_dict[j]))\n",
    "\n",
    "    # Pop edges from cheapest to most expensive; a forest over n vertices\n",
    "    # holds at most n - 1 edges, so stop as soon as that many are chosen\n",
    "    while heap and len(res) < size - 1:\n",
    "        weight, node1, node2 = heappop(heap)\n",
    "        # Both endpoints already in the same tree: the edge would close a cycle\n",
    "        if same(node1, node2): continue\n",
    "        # Join the two trees and record the edge\n",
    "        union(node1, node2)\n",
    "        res.append((node1.val, node2.val, weight))\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(2, 5, 1), (4, 6, 1), (1, 2, 2), (2, 4, 2), (0, 2, 3), (1, 3, 7), (5, 7, 8)]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Kruskal on the adjacency matrix; edges are shown as (node1, node2, weight)\n",
    "kruskal_matrix(graph_matrix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 树型动态规划"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this rebinds the name Node, replacing the union-find Node class from\n",
    "# the spanning-tree section. Re-run that cell before calling kruskal_dict or\n",
    "# kruskal_matrix again, because this class has no father attribute.\n",
    "class Node(object):\n",
    "    \"\"\"Node of a general tree: a value plus a list of child nodes.\"\"\"\n",
    "    def __init__(self, val=None,):\n",
    "        self.val = val\n",
    "        self.children = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimum number of subtrees the tree can be split into such that, inside\n",
    "# each subtree, every node is a direct child of that subtree's root\n",
    "def least_sub_tree(root):\n",
    "    \"\"\"Return the minimum number of such subtrees for the tree at root.\n",
    "\n",
    "    Args:\n",
    "        root: tree node exposing a children list, or None for an empty tree.\n",
    "\n",
    "    Returns:\n",
    "        0 for an empty tree, otherwise the minimum subtree count.\n",
    "    \"\"\"\n",
    "    if root is None: return 0\n",
    "\n",
    "    dp = {} # memo: node -> result of dfs(node)\n",
    "\n",
    "    def dfs(node):\n",
    "        # Subtree count for the tree under node when node itself is a subtree root\n",
    "        if not node.children: return 1\n",
    "        if node in dp: return dp[node]\n",
    "        num = 1\n",
    "        for child in node.children:\n",
    "            # Each child either stays attached to node (then all of its own\n",
    "            # children must start new subtrees) or detaches and becomes a\n",
    "            # subtree root itself; keep the cheaper option\n",
    "            num += min(sum(dfs(i) for i in child.children), dfs(child))\n",
    "        dp[node] = num\n",
    "        return num\n",
    "\n",
    "    return dfs(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a three-level demo tree: the root has ten children valued 0..9, and\n",
    "# the child valued i has leaf children valued i+1..9 (so the last has none)\n",
    "root = Node(0)\n",
    "for i in range(10):\n",
    "    child1 = Node(i)\n",
    "    root.children.append(child1)\n",
    "    for j in range(i + 1, 10):\n",
    "        child2 = Node(j)\n",
    "        child1.children.append(child2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<__main__.Node at 0x11173fb70>,\n",
       " <__main__.Node at 0x11173fda0>,\n",
       " <__main__.Node at 0x11173ff98>,\n",
       " <__main__.Node at 0x111741198>,\n",
       " <__main__.Node at 0x111741320>,\n",
       " <__main__.Node at 0x111741470>,\n",
       " <__main__.Node at 0x111741588>,\n",
       " <__main__.Node at 0x111741668>,\n",
       " <__main__.Node at 0x111741710>,\n",
       " <__main__.Node at 0x111741780>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The ten first-level nodes\n",
    "root.children"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<__main__.Node at 0x1117414a8>,\n",
       " <__main__.Node at 0x1117414e0>,\n",
       " <__main__.Node at 0x111741518>,\n",
       " <__main__.Node at 0x111741550>]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Children of the first-level node valued 5: the four leaves valued 6..9\n",
    "root.children[5].children"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Minimum number of subtrees for the demo tree built above\n",
    "least_sub_tree(root)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3(py37)",
   "language": "python",
   "name": "py37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
